import groovy.json.JsonSlurper

import java.net.http.HttpClient
import java.util.concurrent.CompletableFuture
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.locks.ReentrantLock
import java.util.function.Consumer

/**
 * Turns a stream of OpenAI chat-completion SSE JSON payloads into line-sized {@link Chunk}s.
 *
 * <p>Each payload handed to the consumer returned by {@link #getChunkingConsumer()} is parsed as
 * JSON. Text deltas are buffered until the accumulated text ends with a line break; the buffer is
 * then enqueued as one {@code Chunk} with every line break replaced by a space. A payload whose
 * first choice carries {@code finish_reason} flushes any buffered text and enqueues a final chunk
 * carrying that reason. A payload that cannot be processed enqueues a final error chunk.
 *
 * <p>Producers (the consumer) and readers ({@link #getNextChunk()},
 * {@link #getAggregatedChunk()}) communicate through {@link #queue}.
 */
class OpenAISSEChunkingConsumer {
    /** Chunks ready to be read; filled by the chunking consumer and drained by the getters. */
    public LinkedBlockingQueue<Chunk> queue
    /** Fair lock serializing payload handling so the text buffer is mutated by one thread at a time. */
    private ReentrantLock eventLock
    private jsonSlurper
    /** Text received since the last emitted chunk; guarded by {@link #eventLock}. */
    private currentChunkContent = ""

    /** Any single line break; replaced by a space in emitted chunks. */
    private static final LINE_BREAK = ~/[\n\r]/
    /** One or more line breaks at the end of the buffered text; triggers a flush. */
    private static final TRAILING_LINE_BREAKS = ~/[\n\r]+\Z/

    /**
     * @param jsonSlurper parser used to decode each SSE payload
     */
    public OpenAISSEChunkingConsumer(JsonSlurper jsonSlurper) {
        this.queue = new LinkedBlockingQueue<Chunk>()
        this.eventLock = new ReentrantLock(true)
        this.jsonSlurper = jsonSlurper
    }

    /**
     * Returns a consumer that accepts one raw SSE JSON payload per invocation.
     *
     * <p>Only payloads whose {@code object} is {@code "chat.completion.chunk"} and which have a
     * non-empty {@code choices} list are acted upon; the first choice is inspected. Any exception
     * while handling a payload is converted into a final error chunk containing the payload text
     * and the exception message.
     *
     * @return consumer that feeds {@link #queue}
     */
    public Consumer getChunkingConsumer() {
        def consumer = chunk -> {
            eventLock.lock()
            try {
                def chunkObj = jsonSlurper.parseText(chunk)
                // Groovy truth covers both a missing and an empty "choices" list, so such payloads
                // are skipped instead of raising an NPE that would end the stream with an error.
                if (chunkObj.object == "chat.completion.chunk" && chunkObj.choices) {
                    def choice = chunkObj.choices[0]
                    if (choice.finish_reason) {
                        flushBufferedContent()
                        queue.add(new Chunk('', true, false, choice.finish_reason))
                    } else {
                        // delta may be absent on some payloads; treat that as "no new text".
                        def content = choice.delta?.content
                        if (content) {
                            currentChunkContent += content
                            // Emit only once the buffered text ends a line, so chunks are whole lines.
                            if (TRAILING_LINE_BREAKS.matcher(currentChunkContent).find()) {
                                flushBufferedContent()
                            }
                        }
                    }
                }
            } catch (Exception ex) {
                // Drop partial text so it cannot leak into a later response handled by this instance.
                currentChunkContent = ""
                queue.add(new Chunk("", true, true, "[${chunk}] ${ex.message}"))
            } finally {
                eventLock.unlock()
            }
        }
        return consumer
    }

    /**
     * Enqueues the buffered text (line breaks turned into spaces) when non-empty, then clears the
     * buffer so the next response starts fresh. Caller must hold {@link #eventLock}.
     */
    private void flushBufferedContent() {
        if (currentChunkContent) {
            queue.add(new Chunk(LINE_BREAK.matcher(currentChunkContent).replaceAll(' ')))
        }
        currentChunkContent = ""
    }

    //Mandatory method: Get next chunk from response. 
    /**
     * Removes and returns the next queued chunk, blocking until one is available.
     *
     * @return the next chunk, possibly a final (finish or error) chunk
     */
    public Chunk getNextChunk() {
        return queue.take()
    }
    
    // Mandatory method: Aggregate entire response in a single chunk
    /**
     * Blocks until a final chunk is dequeued and returns all preceding chunk contents joined with
     * {@code '\n'} as a single final chunk.
     *
     * <p>NOTE(review): the final chunk itself (finish reason or error message) is discarded, so an
     * error produced by the consumer is not visible to callers of this method — confirm whether
     * that is intended.
     *
     * @return a final chunk holding the aggregated text
     */
    public Chunk getAggregatedChunk() {
        def chunks = []
        while (true) {
            def chunk = queue.take()
            if (chunk?.isFinal) break
            chunks << chunk.content
        }
        return new Chunk(chunks.join('\n'), true)
    }
}